import cdtools
from cdtools.tools import plotting as p
from matplotlib import pyplot as plt
from scipy.io import loadmat, savemat
import numpy as np
import torch as t
from helper_functions import synthesize_magnetic_results, get_circular_lineout, calc_pcfrc, calc_pcmse, calc_sqrt_fidelity
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator, LogLocator)

# Figure backgrounds: the figure canvas and saved files are fully transparent,
# while the plotting area itself stays opaque white, so saved PDFs get a clear
# background.
_TRANSPARENT_WHITE = (1.0, 1.0, 1.0, 0.0)
_OPAQUE_WHITE = (1.0, 1.0, 1.0, 1.0)
plt.rcParams['figure.facecolor'] = _TRANSPARENT_WHITE
plt.rcParams['axes.facecolor'] = _OPAQUE_WHITE
plt.rcParams['savefig.facecolor'] = _TRANSPARENT_WHITE

# Output switches:
#   save_synthesized -- write the aligned half/half/full reconstructions to
#                       results/<dataset stem>_synthesized.mat
#   save_figures     -- save the relevant figures into ./figs; when False the
#                       figures are only plotted
save_synthesized = True
save_figures = True

if save_figures:
    # Create the output directory up front so every later save succeeds.
    from pathlib import Path
    Path('figs').mkdir(exist_ok=True)


#
#
# First, we load and process the right circularly polarized (RCP) results
#
#

    
# Ptychography dataset: every exposure in it is a single shot.
filename = 'magnetic_ptycho_rcp.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(
    f'data/{filename}', cut_zeros=False)

# Detector gain chain, used to convert detector counts (ADU) into photons
# when measuring the signal level: photons -> electrons -> ADU.
electronsperphoton = 60 / 3.6
# 2**16 ADU per 200000 electrons, from the MTE-2048B datasheet
ADUperelectron = 2**16 / 200000
ADUperphoton = electronsperphoton * ADUperelectron

# In one preliminary computational experiment, we compared the quality of
# super-resolution ptychography reconstructions at different padding
# levels. We defined the approximate center of the Siemens star in this way
# to enable that comparison, although for the final analysis this is not
# necessary.

# Load the first 50%, second 50%, and full-data reconstructions, in that order.
half_1, half_2, full = (
    loadmat(f'results/single_shot_magnetic_ptycho_rcp_{part}.mat')
    for part in ('half_1', 'half_2', 'full'))

# The reconstructions used for the paper were written by an older CDTools
# version that stored these fields as 'obj_basis' / 'translations'. Alias
# them to the current names. Only half_1 is checked; all three files are
# assumed to come from the same CDTools version.
if 'basis' not in half_1:
    for recon in (half_1, half_2, full):
        recon['basis'] = recon['obj_basis']
        recon['translation'] = recon['translations']


# Step 1: crop all three objects to a large window that discards most of the
# surrounding noise. The noise at the very edge is large enough to cause
# artifacts when the objects are subpixel-shifted.
window = np.s_[300:1100, 350:1150]
for recon in (half_1, half_2, full):
    recon['obj'] = recon['obj'][window]

# Step 2: the smaller analysis window, in coordinates of the cropped objects.
# It is passed to the comparison below and reused by the plotting section.
# An earlier choice was np.s_[275:525, 250:500]. To check it visually:
#   p.plot_amplitude(full['obj']); p.plot_amplitude(full['obj'][window])
#   plt.show()
window = np.s_[340:500, 300:460]



# Compare the two 50% data reconstructions with the full-data result. Phase
# ramps, shifts, and an overall amplitude exponent are corrected; the exponent
# is left unconstrained by the per-exposure weighting factor.
results = synthesize_magnetic_results(half_1, half_2, full, window)

if save_synthesized:
    # 'magnetic_ptycho_rcp.cxi' -> 'results/magnetic_ptycho_rcp_synthesized.mat'
    stem = '.'.join(filename.split('.')[:-1])
    savemat(f'results/{stem}_synthesized.mat', results)


#
# Plotting Section
#

# Shorthands for frequently used entries of the synthesized results.
basis, freqs, ssnr, frc, threshold = (
    results[key]
    for key in ('basis', 'frc_freqs', 'ssnr', 'frc', 'frc_threshold'))

# Object pixel size taken from the basis; the 1e10 factor converts it to
# Angstrom (so the basis is presumably in metres). Rounding to an integer
# keeps 3-4 significant figures.
psize = np.abs(basis[0, 1])
psize_A = int(np.round(psize * 1e10))
print('Pixel size in Angstrom:', psize_A)

# Bar chart of the fraction of total probe power carried by each probe mode,
# as a percentage.
mode_power = np.sum(np.abs(results['probe_full'])**2, axis=(1, 2))
intensities = mode_power * 100 / np.sum(mode_power)
mode_index = list(range(len(intensities)))

plt.figure(figsize=(3.75, 2.5))
plt.bar(mode_index, intensities)
plt.xlabel('Mode Number')
plt.ylabel('Power Fraction (%)')
plt.tight_layout()


def savefig_plus_im(fname):
    """Save the current figure as a PDF, plus its image at full resolution.

    In a PDF, an ``imshow`` image is stored as a sampled/rescaled copy of
    the data. So besides ``fname + '.pdf'``, this writes the image array of
    the current axes to ``fname + '_<N>Apix.png'``, with one PNG pixel per
    array element. N is the module-level ``psize_A``.

    The PNG uses the same colormap, colour limits (``norm.vmin/vmax``), and
    origin as the displayed image. Without the limits and origin, ``imsave``
    would autoscale to the data range and use the default origin, so the
    PNG could differ from the PDF in colour scaling or orientation.

    Parameters
    ----------
    fname : str
        Output path without extension.
    """
    plt.savefig(fname + '.pdf')
    im = plt.gca().get_images()[0]
    # For RGB(A) images (e.g. colorized plots) imsave ignores cmap/vmin/vmax.
    plt.imsave(fname + '_%dApix.png' % psize_A, im.get_array(),
               cmap=im.cmap, vmin=im.norm.vmin, vmax=im.norm.vmax,
               origin=im.origin)


# --- Raw detector data ---

# Per-exposure total signal in photons. This is computed from the patterns
# as loaded, i.e. before negative values are clamped to zero below.
intensities = t.sum(dataset.patterns, dim=(1, 2)).numpy() / ADUperphoton
mean_signal, std_signal = np.mean(intensities), np.std(intensities)
for label, value in (('Mean', mean_signal), ('STD', std_signal)):
    print(f'{label} ptycho signal Level:', int(value), 'photons')

# Clamp negative detector values to zero, so log10(x + 1) below is >= 0.
dataset.patterns = t.clamp(dataset.patterns, min=0)
first_pattern = dataset.patterns[0].numpy()

p.plot_real(np.log10(first_pattern + 1))
plt.title('Log base 10 of ptycho detector data + 1 (ADU)')
if save_figures:
    savefig_plus_im('figs/rcp_ptycho_data_adu')

p.plot_real(np.log10(first_pattern / ADUperphoton + 1))
plt.title('Log base 10 of ptycho detector data + 1 (Photons)')
if save_figures:
    savefig_plus_im('figs/rcp_ptycho_data_photons')


# Second, we plot the ptychography images over the full (large) crop window,
# using a few different colour scales.

# Raw amplitude, clipped to [0, 0.6] to hard-set the colour scale.
to_plot = np.abs(results['obj_full'])
to_plot = np.clip(to_plot, a_min=0, a_max=0.6)
p.plot_amplitude(to_plot, basis=basis)
plt.title('Amplitude of the full ptychography result (full window)')
if save_figures:
    savefig_plus_im('figs/rcp_ptycho_obj_amplitude_full_window')

# Amplitude normalized to its mean inside the analysis window and shown on
# a fixed [0.85, 1.15] colour scale.
to_plot = np.abs(results['obj_full'])
to_plot = to_plot / np.mean(np.abs(results['obj_full'][window]))
to_plot = np.clip(to_plot, a_min=0.85, a_max=1.15)
p.plot_amplitude(to_plot, basis=basis)
plt.title('Amplitude of the full ptychography result (full window)')
if save_figures:
    savefig_plus_im('figs/rcp_ptycho_obj_amplitude_set_colorscale_full_window')

# Phase (radians) clipped to [-0.1, 0.1], re-encoded as exp(1j*phase) for
# plot_phase. The title says "Phase": this plot is not an amplitude.
to_plot = np.angle(results['obj_full'])
to_plot = np.exp(1j*np.clip(to_plot, a_min=-0.1, a_max=0.1))
p.plot_phase(to_plot, basis=basis)
plt.title('Phase of the full ptychography result (full window)')
if save_figures:
    savefig_plus_im('figs/rcp_ptycho_obj_phase_set_colorscale_full_window')

# Next, we plot the ptychography images cropped to the analysis window.

# Amplitudes of the full-data and the two 50%-data results. Each is
# normalized to its own mean over the window and shown on a fixed
# [0.85, 1.15] colour scale.
# The '50\%' titles are raw strings: the displayed text (including the
# backslash) is unchanged, but the invalid '\%' escape no longer triggers a
# SyntaxWarning on Python >= 3.12.
windowed_amplitude_plots = (
    ('obj_full',
     'Amplitude of the full ptychography result',
     'figs/rcp_ptycho_obj_amplitude'),
    ('obj_half_1',
     r'Amplitude of the first 50\% ptychography result',
     'figs/rcp_ptycho_half1_obj_amplitude'),
    ('obj_half_2',
     r'Amplitude of the second 50\% ptychography result',
     'figs/rcp_ptycho_half2_obj_amplitude'),
)
for obj_key, title, fig_name in windowed_amplitude_plots:
    to_plot = np.abs(results[obj_key][window])
    to_plot = to_plot / np.mean(to_plot)  # to_plot is already non-negative
    to_plot = np.clip(to_plot, a_min=0.85, a_max=1.15)
    p.plot_amplitude(to_plot, basis=basis)
    plt.title(title)
    if save_figures:
        savefig_plus_im(fig_name)

# Phase (radians) of the full-data result, clipped to [-0.1, 0.1] and
# re-encoded as exp(1j*phase) for plot_phase.
to_plot = np.angle(results['obj_full'][window])
to_plot = np.exp(1j*np.clip(to_plot, a_min=-0.1, a_max=0.1))
p.plot_phase(to_plot, basis=basis)
plt.title('Phase of the full ptychography result')
if save_figures:
    savefig_plus_im('figs/rcp_ptycho_obj_phase')
    
# Colorized (amplitude + phase) view of the full-data object in the
# analysis window.
windowed_obj = results['obj_full'][window]
p.plot_colorized(windowed_obj, basis=basis)
plt.title('Colorized Image of the full ptychography result')
if save_figures:
    savefig_plus_im('figs/rcp_ptycho_obj_colorized')

# Base-10 log (computed as ln(x) / ln(10)) of the far-field amplitude of the
# same windowed object.
fourier_amplitude = t.abs(
    cdtools.tools.propagators.far_field(t.as_tensor(windowed_obj)))
p.plot_real(t.log(fourier_amplitude) / np.log(10))
plt.title('Log Base 10 Amplitude of FFT of ptychography reconstruction')
if save_figures:
    savefig_plus_im('figs/rcp_ptycho_obj_fourier_amplitude')

# Disabled: illumination-map plot (its photons-per-pixel scaling was marked
# as wrong).
#   p.plot_amplitude(results['illumination_map_full'][window], basis=basis)
#   plt.title('Illumination Intensity (photons per pixel, wrong)')


probe_basis = results['original_probe_basis']
print('Probe basis:', probe_basis)

# Total probe intensity, i.e. |mode|^2 summed over all modes (no title set).
p.plot_amplitude(np.sum(np.abs(results['probe_full'])**2, axis=0),
                 basis=probe_basis)
if save_figures:
    savefig_plus_im('figs/rcp_ptycho_probe_full_intensity')

# The first three probe modes: all three as amplitudes, then all three
# colorized. Figures are saved as ..._<view>_mode_<1-based index>.
mode_names = ('top', 'second', 'third')
probe_views = (
    (p.plot_amplitude, 'Amplitude of the full ptycho {} probe mode', 'amplitude'),
    (p.plot_colorized, 'Colorized full ptycho {} probe mode', 'colorized'),
)
for plot_fn, title_template, view in probe_views:
    for idx, mode_name in enumerate(mode_names):
        plot_fn(results['probe_full'][idx], basis=probe_basis)
        plt.title(title_template.format(mode_name))
        if save_figures:
            savefig_plus_im(f'figs/rcp_ptycho_probe_full_{view}_mode_{idx + 1}')

# Frequencies scaled by 1e-6 for the "cycles / um" axis labels.
freqs_per_um = freqs * 1e-6

# FRC from the synthesized results, with its threshold curve dashed.
plt.figure(figsize=(3.75, 2.5))
plt.plot(freqs_per_um, frc, 'k-')
plt.plot(freqs_per_um, threshold, 'k--')
plt.xlabel('Spatial Frequency (cycles / um)')
plt.ylabel('FRC')
plt.grid(True)
plt.tight_layout()

# SSNR on a log axis, with a dashed reference level at 0.4142 (~ sqrt(2) - 1),
# built as 0.4142 + 0*ssnr so that it lives on the same frequency grid.
plt.figure(figsize=(3.75, 2.5))
plt.semilogy(freqs_per_um, ssnr, 'k-')
plt.plot(freqs_per_um, 0.4142 + 0 * ssnr, 'k--')
minor_locator = plt.gca().yaxis.get_minor_locator()
minor_locator.set_params(numticks=99, subs=[.2, .4, .6, .8])
plt.grid(True, which='both')
plt.xlabel('Spatial Frequency (cycles / um)')
plt.ylabel('SSNR')
plt.tight_layout()


def _mode_summed_intensity(modes):
    """Return sum over the first axis of |modes|**2, as a torch tensor."""
    return t.as_tensor(np.sum(np.abs(modes)**2, axis=0))


def _shift_each_mode(modes, shift):
    """Sinc-subpixel-shift every mode by the same shift; return a numpy stack."""
    shift_fn = cdtools.tools.image_processing.sinc_subpixel_shift
    return np.stack([shift_fn(t.as_tensor(mode), shift).numpy()
                     for mode in modes])


# Align the second half-data probe to the first before comparing them.
# First register the mode-summed intensities in real space, then register
# the mode-summed far-field intensities (a shift in Fourier space is a linear
# phase ramp in real space), and finally propagate back to real space.
p1 = results['probe_half_1']
p2 = results['probe_half_2']

probe_shift = cdtools.tools.image_processing.find_shift(
    _mode_summed_intensity(p1), _mode_summed_intensity(p2))
shifted_p2 = _shift_each_mode(p2, probe_shift)

fft_p1 = cdtools.tools.propagators.far_field(t.as_tensor(p1)).numpy()
fft_p2 = cdtools.tools.propagators.far_field(t.as_tensor(shifted_p2)).numpy()

probe_fft_shift = cdtools.tools.image_processing.find_shift(
    _mode_summed_intensity(fft_p1), _mode_summed_intensity(fft_p2))
shifted_fft_p2 = _shift_each_mode(fft_p2, probe_fft_shift)

final_p2 = cdtools.tools.propagators.inverse_far_field(
    t.as_tensor(shifted_fft_p2)).numpy()



# Compare the first half-data probe with the aligned second one (final_p2).
# NOTE: this rebinds `freqs`, which previously held the object FRC
# frequencies. The 40 is passed straight through to calc_pcfrc (presumably
# a bin count — confirm against helper_functions).
freqs, pcfrc = calc_pcfrc(t.as_tensor(p1), t.as_tensor(final_p2), 40)

# Rescale by the probe pixel size. The probe basis may differ from the
# object basis used for the FRC above.
freqs *= 1 / (np.abs(results['original_probe_basis'][0,1]))

pcmse = calc_pcmse(t.as_tensor(p1), t.as_tensor(final_p2))
print('PCMSE between probes:', pcmse.cpu().numpy())
sqrtfid = calc_sqrt_fidelity(t.as_tensor(p1), t.as_tensor(final_p2))
# Normalize by the total powers of the same two probes that entered the
# fidelity: p1 and the aligned final_p2 (not the unaligned p2).
norm_pcmse = 1 - sqrtfid**2 / (np.sum(np.abs(p1)**2)
                               * np.sum(np.abs(final_p2)**2))
print('Normalized PCMSE between probes:', norm_pcmse.cpu().numpy())
print('sqrt fidelity between probes:', sqrtfid.cpu().numpy())


# Probe PCFRC versus spatial frequency (axis labelled in cycles / um), shown
# from -0.25 to 5 on the x axis.
fig, ax = plt.subplots(figsize=(3.75, 2.5))
ax.plot(freqs * 1e-6, pcfrc, 'k-')
ax.set_xlabel('Spatial Frequency (cycles / um)')
ax.set_ylabel('PCFRC')
ax.set_xlim([-0.25, 5])
ax.grid(True)
fig.tight_layout()
if save_figures:
    fig.savefig('figs/rcp_probe_pcfrc.pdf')

plt.show()
